#!/usr/bin/env python3

import copy
import rospy
import threading
import numpy as np
from interbotix_xs_msgs.msg import LocobotJoy
from interbotix_common_modules import angle_manipulation as ang
from interbotix_xs_modules.locobot import InterbotixLocobotXS

### @brief Turns processed joystick messages into commands for an Interbotix Locobot
### @details One-shot commands (loop-rate changes, odometry reset, gripper actions) are executed
###          directly in the subscriber callback; continuous commands (camera, base, waist and
###          end-effector motion) are applied by 'controller', which is meant to be called once
###          per control-loop iteration.
class LocobotRobot(object):
    ### @brief Reads the node's private parameters, creates the robot interface and subscribes
    ###        to the processed joystick topic
    ### @details The robot interface is created with init_node=False, so the ROS node is
    ###          presumably expected to be initialised by the caller beforehand.
    def __init__(self):
        # Increment applied per control cycle to the waist joint and the camera pan/tilt joints
        self.waist_step = 0.06
        # Frequency [Hz] of the main control loop, plus the rate remembered for each speed mode
        self.current_loop_rate = 25
        self.loop_rates = {"course" : 25, "fine" : 25}
        # Latest joystick message; written by the subscriber callback and read by 'controller',
        # so both sides access it while holding 'joy_mutex'
        self.joy_msg = LocobotJoy()
        self.joy_mutex = threading.Lock()
        self.rate = rospy.Rate(self.current_loop_rate)
        self.use_base = rospy.get_param("~use_base")
        robot_model = rospy.get_param("~robot_model")
        self.base_type = rospy.get_param("~base_type")
        robot_name = rospy.get_namespace().strip("/")
        # The arm model is built from the second underscore-separated token of the robot model;
        # a result of "mobile_base" is treated as "no arm attached"
        self.arm_model = "mobile_" + robot_model.split("_")[1]
        if self.arm_model == "mobile_base":
            self.arm_model = None
        self.locobot = InterbotixLocobotXS(
            robot_model=robot_model,
            arm_model=self.arm_model,
            robot_name=robot_name,
            init_node=False
        )
        if self.arm_model is not None:
            # Per-cycle increments for end-effector roll/pitch and x/y/z translation
            self.rotate_step = 0.04
            self.translate_step = 0.01
            # Gripper pressure is tracked as a fraction (logged as a percentage)
            self.gripper_pressure_step = 0.125
            self.current_gripper_pressure = 0.5
            arm_info = self.locobot.arm.group_info
            self.num_joints = arm_info.num_joints
            self.waist_index = arm_info.joint_names.index("waist")
            self.waist_ll = arm_info.joint_lower_limits[self.waist_index]
            self.waist_ul = arm_info.joint_upper_limits[self.waist_index]
            # Both transforms start as identity and are filled in by update_T_yb()
            self.T_sy = np.identity(4)
            self.T_yb = np.identity(4)
            self.update_T_yb()
        rospy.Subscriber("commands/joy_processed", LocobotJoy, self.joy_control_cb)

    ### @brief Helper function that returns a value moved one step up or down depending on a command
    ### @param value - current value
    ### @param cmd - command code received from the joystick message
    ### @param inc_cmd - command code that means "increase"
    ### @param dec_cmd - command code that means "decrease"
    ### @param step - amount to add or subtract
    ### @return value + step if cmd equals inc_cmd, value - step if cmd equals dec_cmd,
    ###         otherwise value unchanged
    @staticmethod
    def _stepped(value, cmd, inc_cmd, dec_cmd, step):
        if cmd == inc_cmd:
            return value + step
        if cmd == dec_cmd:
            return value - step
        return value

    ### @brief Helper function that updates the frequency at which the main control loop runs
    ### @param loop_rate - desired loop frequency [Hz]
    def update_speed(self, loop_rate):
        self.current_loop_rate = loop_rate
        self.rate = rospy.Rate(self.current_loop_rate)
        rospy.loginfo("Current loop rate is %d Hz." % self.current_loop_rate)

    ### @brief Helper function that recomputes T_sy and T_yb from the arm's commanded end-effector pose
    ### @details The upper-left 2x2 block of T_sy is set to the rotation matrix for the yaw of the
    ###          commanded pose; T_yb is then the commanded pose pre-multiplied by the inverse of T_sy,
    ###          so that np.dot(T_sy, T_yb) reproduces the commanded pose.
    def update_T_yb(self):
        T_sb = self.locobot.arm.get_ee_pose_command()
        rpy = ang.rotationMatrixToEulerAngles(T_sb[:3, :3])
        self.T_sy[:2,:2] = ang.yawToRotationMatrix(rpy[2])
        self.T_yb = np.dot(ang.transInv(self.T_sy), T_sb)

    ### @brief Helper function that stores and applies a new gripper pressure
    ### @param gripper_pressure - desired gripper pressure, as a fraction (1.0 is logged as 100%)
    def update_gripper_pressure(self, gripper_pressure):
        self.current_gripper_pressure = gripper_pressure
        self.locobot.gripper.set_pressure(self.current_gripper_pressure)
        rospy.loginfo("Gripper pressure is at %.2f%%." % (self.current_gripper_pressure * 100.0))

    ### @brief ROS Subscriber Callback function that stores the latest joystick message and
    ###        immediately executes its one-shot commands
    ### @param msg - LocobotJoy message received on the "commands/joy_processed" topic
    def joy_control_cb(self, msg):
        with self.joy_mutex:
            self.joy_msg = copy.deepcopy(msg)

        # Check the speed_cmd (the guards keep manual adjustments between 10 and 40 Hz)
        if msg.speed_cmd == LocobotJoy.SPEED_INC and self.current_loop_rate < 40:
            self.update_speed(self.current_loop_rate + 1)
        elif msg.speed_cmd == LocobotJoy.SPEED_DEC and self.current_loop_rate > 10:
            self.update_speed(self.current_loop_rate - 1)

        # Check the speed_toggle_cmd: remember the current rate under the other mode's key,
        # then switch to the rate stored for the requested mode
        if msg.speed_toggle_cmd == LocobotJoy.SPEED_COURSE:
            self.loop_rates["fine"] = self.current_loop_rate
            rospy.loginfo("Switched to Course Control")
            self.update_speed(self.loop_rates["course"])
        elif msg.speed_toggle_cmd == LocobotJoy.SPEED_FINE:
            self.loop_rates["course"] = self.current_loop_rate
            rospy.loginfo("Switched to Fine Control")
            self.update_speed(self.loop_rates["fine"])

        # check base_reset_odom_cmd
        if msg.base_reset_odom_cmd == LocobotJoy.RESET_ODOM and self.use_base:
            if self.base_type == 'kobuki':
                self.locobot.base.reset_odom()
            elif self.base_type == 'create3':
                rospy.logwarn("Can't reset odometry using a Create 3 base.")

        # The remaining commands need an arm (and therefore a gripper)
        if self.arm_model is None:
            return

        # Check the gripper_cmd
        if msg.gripper_cmd == LocobotJoy.GRIPPER_OPEN:
            self.locobot.gripper.open(delay=0)
        elif msg.gripper_cmd == LocobotJoy.GRIPPER_CLOSE:
            self.locobot.gripper.close(delay=0)

        # Check the gripper_pwm_cmd (pressure is only raised below 1 and only lowered above 0)
        if msg.gripper_pwm_cmd == LocobotJoy.GRIPPER_PWM_INC and self.current_gripper_pressure < 1:
            self.update_gripper_pressure(self.current_gripper_pressure + self.gripper_pressure_step)
        elif msg.gripper_pwm_cmd == LocobotJoy.GRIPPER_PWM_DEC and self.current_gripper_pressure > 0:
            self.update_gripper_pressure(self.current_gripper_pressure - self.gripper_pressure_step)

    ### @brief Runs one control cycle using a snapshot of the most recent joystick message
    ### @details Camera and base commands are always processed; arm-related commands (preset poses,
    ###          waist rotation, end-effector motion) are skipped when no arm is attached.
    def controller(self):
        # Work on a private copy so the callback can keep updating the shared message
        with self.joy_mutex:
            msg = copy.deepcopy(self.joy_msg)

        self._control_camera(msg)

        # check base related commands
        if self.use_base:
            self.locobot.base.command_velocity(msg.base_x_cmd, msg.base_theta_cmd)

        if self.arm_model is None:
            return

        self._control_arm_pose(msg)
        self._control_waist(msg)
        self._control_ee(msg)

    ### @brief Helper function that homes or incrementally rotates the camera pan/tilt mechanism
    ### @param msg - snapshot of the joystick message for this control cycle
    def _control_camera(self, msg):
        # check if the pan-and-tilt mechanism should be reset
        if (msg.pan_cmd == LocobotJoy.PAN_TILT_HOME and
            msg.tilt_cmd == LocobotJoy.PAN_TILT_HOME):
            self.locobot.camera.pan_tilt_go_home(1.0, 0.5, 1.0, 0.5, False)

        # check if the pan/tilt mechanism should be rotated
        elif msg.pan_cmd != 0 or msg.tilt_cmd != 0:
            cam_positions = self.locobot.camera.get_joint_commands()
            pan = self._stepped(
                cam_positions[0], msg.pan_cmd,
                LocobotJoy.PAN_CCW, LocobotJoy.PAN_CW, self.waist_step)
            tilt = self._stepped(
                cam_positions[1], msg.tilt_cmd,
                LocobotJoy.TILT_UP, LocobotJoy.TILT_DOWN, self.waist_step)
            self.locobot.camera.pan_tilt_move(pan, tilt, 0.2, 0.1, 0.2, 0.1, False)

    ### @brief Helper function that sends the arm to its Home or Sleep pose when requested
    ### @param msg - snapshot of the joystick message for this control cycle
    def _control_arm_pose(self, msg):
        # Check the pose_cmd
        if msg.pose_cmd == 0:
            return
        if msg.pose_cmd == LocobotJoy.HOME_POSE:
            self.locobot.arm.go_to_home_pose(1.5, 0.75)
        elif msg.pose_cmd == LocobotJoy.SLEEP_POSE:
            self.locobot.arm.go_to_sleep_pose(1.5, 0.75)
        # The commanded end-effector pose may have changed, so refresh the cached transforms
        self.update_T_yb()

    ### @brief Helper function that rotates the waist joint by one step, falling back to the joint
    ###        limit when the stepped position is rejected
    ### @param msg - snapshot of the joystick message for this control cycle
    def _control_waist(self, msg):
        # Check the waist_cmd
        if msg.waist_cmd == 0:
            return

        waist_position = self.locobot.arm.get_single_joint_command("waist")
        if msg.waist_cmd == LocobotJoy.WAIST_CCW:
            target, limit = waist_position + self.waist_step, self.waist_ul
        elif msg.waist_cmd == LocobotJoy.WAIST_CW:
            target, limit = waist_position - self.waist_step, self.waist_ll
        else:
            target, limit = None, None

        if target is not None:
            success = self.locobot.arm.set_single_joint_position("waist", target, 0.2, 0.1, False)
            # If the full step was refused and the joint is not already at the limit,
            # move it exactly onto the limit instead
            if not success and waist_position != limit:
                self.locobot.arm.set_single_joint_position("waist", limit, 0.2, 0.1, False)
        self.update_T_yb()

    ### @brief Helper function that nudges the end-effector position and/or orientation by one step
    ### @param msg - snapshot of the joystick message for this control cycle
    ### @details The cached T_yb is only replaced when the arm accepts the new pose, so a rejected
    ###          command leaves the reference transform untouched.
    def _control_ee(self, msg):
        # Test each field for activity explicitly instead of summing the command codes,
        # so the result does not depend on the sign of the enum values
        position_changed = msg.ee_x_cmd != 0 or msg.ee_z_cmd != 0
        if self.num_joints >= 6:
            position_changed = position_changed or msg.ee_y_cmd != 0
        orientation_changed = msg.ee_roll_cmd != 0 or msg.ee_pitch_cmd != 0

        if not (position_changed or orientation_changed):
            return

        # Copy the most recent T_yb transform into a temporary variable
        T_yb = np.array(self.T_yb)

        if position_changed:
            # check ee_x_cmd
            T_yb[0, 3] = self._stepped(
                T_yb[0, 3], msg.ee_x_cmd,
                LocobotJoy.EE_X_INC, LocobotJoy.EE_X_DEC, self.translate_step)

            # check ee_y_cmd: only applied on arms with at least 6 joints and while the
            # (already updated) x offset is greater than 0.3
            if self.num_joints >= 6 and T_yb[0, 3] > 0.3:
                T_yb[1, 3] = self._stepped(
                    T_yb[1, 3], msg.ee_y_cmd,
                    LocobotJoy.EE_Y_INC, LocobotJoy.EE_Y_DEC, self.translate_step)

            # check ee_z_cmd
            T_yb[2, 3] = self._stepped(
                T_yb[2, 3], msg.ee_z_cmd,
                LocobotJoy.EE_Z_INC, LocobotJoy.EE_Z_DEC, self.translate_step)

        # check end-effector orientation related commands
        if orientation_changed:
            rpy = ang.rotationMatrixToEulerAngles(T_yb[:3, :3])

            # check ee_roll_cmd
            rpy[0] = self._stepped(
                rpy[0], msg.ee_roll_cmd,
                LocobotJoy.EE_ROLL_CCW, LocobotJoy.EE_ROLL_CW, self.rotate_step)

            # check ee_pitch_cmd
            rpy[1] = self._stepped(
                rpy[1], msg.ee_pitch_cmd,
                LocobotJoy.EE_PITCH_DOWN, LocobotJoy.EE_PITCH_UP, self.rotate_step)

            T_yb[:3,:3] = ang.eulerAnglesToRotationMatrix(rpy)

        # Get desired transformation matrix of the end-effector w.r.t. the base frame
        T_sd = np.dot(self.T_sy, T_yb)
        _, success = self.locobot.arm.set_ee_pose_matrix(
            T_sd, self.locobot.arm.get_joint_commands(), True, 0.2, 0.1, False)
        if success:
            self.T_yb = np.array(T_yb)

### @brief Initialises the ROS node and runs the joystick control loop until ROS shuts down
### @details A shutdown request (e.g. Ctrl-C) that arrives while the loop is sleeping surfaces as
###          rospy.ROSInterruptException; it is treated as the normal exit path rather than an error.
def main():
    rospy.init_node('xslocobot_robot')
    locobot = LocobotRobot()
    try:
        while not rospy.is_shutdown():
            locobot.controller()
            # 'rate' is looked up every cycle because the joystick callback may replace it
            locobot.rate.sleep()
    except rospy.ROSInterruptException:
        # Raised by rospy sleeps when the node is shut down mid-sleep; exit quietly
        pass

# Start the control loop only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
